{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何在聊天模型中使用few-shotting示例\n",
    "\n",
    "- 本指南介绍了如何使用示例输入和输出提示聊天模型。为模型提供一些这样的示例称为few-shotting，这是一种简单而强大的方法来指导生成，在某些情况下可以大大提高模型性能。\n",
    "- “few-shotting”常见于机器学习领域，指的是在训练模型时，只使用少量的有标注数据进行学习和训练的方法。这种方法与传统的需要大量数据进行训练的方式有所不同，旨在通过有限的示例让模型获得一定的泛化能力和预测能力。\n",
    "- few-shotting提示模板的目标是根据输入动态选择示例，然后在最终提示中格式化示例以提供模型。\n",
    "- 对于如何最好地进行few-shotting提示似乎没有可靠的共识，最佳提示编译可能会因模型而异。因此，langchain提供了像FewShotChatMessagePromptTemplate这样的few-shotting提示模板作为一个灵活的起点，你可以根据需要修改或替换它们。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "[notice] A new release of pip is available: 23.1.2 -> 24.1.2\n",
      "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
     ]
    }
   ],
   "source": [
    "%pip install -qU langchain langchain_openai langchain_chroma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 下面是一个简单的演示。首先，定义您想包含的示例。让我们给LLM一个不熟悉的数学运算符，用“🦜”表情符号表示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The expression \"2 🦜 9\" is not a standard mathematical operation or equation. It appears to be a combination of the number 2 and the parrot emoji 🦜 followed by the number 9. It does not have a specific mathematical meaning.'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "model = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0.0)\n",
    "\n",
    "chain = model | StrOutputParser()\n",
    "\n",
    "chain.invoke(\"What is 2 🦜 9?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 现在让我们看看如果我们给LLM一些示例来使用会发生什么。我们将在下面定义一些："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples = [\n",
    "    {\"input\": \"2 🦜 2\", \"output\": \"4\"},\n",
    "    {\"input\": \"2 🦜 3\", \"output\": \"5\"},\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 接下来，将它们组装成few-shot提示模板。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[HumanMessage(content='2 🦜 2'), AIMessage(content='4'), HumanMessage(content='2 🦜 3'), AIMessage(content='5')]\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate\n",
    "\n",
    "# This is a prompt template used to format each individual example.\n",
    "example_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"human\", \"{input}\"),\n",
    "        (\"ai\", \"{output}\"),\n",
    "    ]\n",
    ")\n",
    "few_shot_prompt = FewShotChatMessagePromptTemplate(\n",
    "    # 官方给的示例中没有这个参数，导致运行时一直报错，实在是太坑了。\n",
    "    input_variables=[\"input\"],\n",
    "    example_prompt=example_prompt,\n",
    "    examples=examples,\n",
    ")\n",
    "\n",
    "print(few_shot_prompt.invoke({\"input\": \"\"}).to_messages())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'11'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"You are a wondrous wizard of math.\"),\n",
    "        few_shot_prompt,\n",
    "        (\"human\", \"{input}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = final_prompt | model | StrOutputParser()\n",
    "\n",
    "chain.invoke({\"input\": \"What is 2 🦜 9?\"})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
